import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import matplotlib.lines as mlines
import seaborn as sns
import numpy as np
import glob

import sys

# Filtering mode, taken from the first command-line argument:
#   'entityfilter'    keep only the most frequent arguments (top 25% of distinct
#                     ARG0/ARG1 values per cluster count) -- the default
#   'narrativefilter' keep only arguments that appear in the narratives files
# Previously the argument was cast with int() (which crashed on either name) and
# then overwritten with a hard-coded 'entityfilter', so it could never take effect.
VALID_FILTERS = ('entityfilter', 'narrativefilter')
args = sys.argv
filtering = args[1] if len(args) > 1 else 'entityfilter'
if filtering not in VALID_FILTERS:
    sys.exit('Unknown filter {0!r}; expected one of: {1}'.format(
        filtering, ', '.join(VALID_FILTERS)))

in_dir = '../data/human_validation/human_validation_20210603/'

# One annotation CSV per freelancer. Sorted so the row order of the combined
# frame does not depend on filesystem listing order.
all_paths = sorted(glob.glob('{0}*.csv'.format(in_dir)))
if not all_paths:
    raise FileNotFoundError('No annotation CSV files found in {0}'.format(in_dir))

annotation_frames = []
for f in all_paths:
    df_freelancer = pd.read_csv(f)
    # Normalise header case so files from different annotators line up.
    df_freelancer.columns = df_freelancer.columns.str.lower()
    annotation_frames.append(df_freelancer)

# A single concat replaces repeated DataFrame.append, which was removed in pandas 2.0.
df = pd.concat(annotation_frames)

# NOTE(review): assumes every file has exactly these six columns in this order
# once combined -- confirm against the annotation template.
df.columns = ['arg', 'arg-raw', 'similarity', 'arg_no', 'true', 'cluster']

# Cluster counts for which annotations exist; each has its own narratives file.
CLUSTERS = [100, 500, 1000, 2000]
NARRATIVES_PATH = '../data/gpo_final_data/narratives_complete_with_metadata_{0}.csv'
# Share of distinct arguments, ranked by frequency, kept by the entity filter.
# (A fixed top-100 cut-off was also tried at some point.)
TOP_ENTITY_SHARE = 0.25


def _load_narratives(cluster):
    """Read the narratives-with-metadata table for one cluster count."""
    return pd.read_csv(NARRATIVES_PATH.format(cluster))


def _narrative_arguments(cluster):
    """Return every value found in the ARGO or ARG1 column of the cluster's narratives."""
    narratives = _load_narratives(cluster)
    print(cluster, len(narratives.narrative.unique()))
    return set(narratives['ARGO']) | set(narratives['ARG1'])


def _top_entities(cluster, share=TOP_ENTITY_SHARE):
    """Return the most frequent arguments (top ``share`` of distinct ones) for a cluster count.

    Occurrences are counted across both the ARG1 and ARGO columns.
    """
    narratives = _load_narratives(cluster)
    # Stack both argument slots into one 'ARG' column. rename() returns new frames,
    # avoiding in-place renames on column slices (SettingWithCopyWarning).
    # pd.concat replaces DataFrame.append, which was removed in pandas 2.0.
    args_long = pd.concat(
        [narratives[['ARG1']].rename(columns={'ARG1': 'ARG'}),
         narratives[['ARGO']].rename(columns={'ARGO': 'ARG'})],
        ignore_index=True,
    )
    args_long['n'] = 1
    args_long['arg_count'] = args_long['n'].groupby(args_long['ARG']).transform('sum')
    # n and arg_count are constant within an ARG, so this leaves one row per argument.
    distinct = args_long.drop_duplicates()
    n_top = int(len(distinct) * share)
    print(cluster, n_top)
    ranked = distinct.sort_values(by=['arg_count'], ascending=False)
    return list(ranked['ARG'].head(n=n_top))


def _keep_for_cluster(frame, cluster, keep):
    """Drop rows of ``cluster`` whose ``arg`` is not in ``keep``; other clusters are untouched."""
    # Same as (cluster != c) | (arg in keep & cluster == c): the second clause
    # is only reached when cluster == c.
    return frame[(frame.cluster != cluster) | frame.arg.isin(keep)]


# Filter for arguments that appear in frequent narratives
if filtering == 'narrativefilter':
    for c in CLUSTERS:
        df = _keep_for_cluster(df, c, _narrative_arguments(c))

# Filter for most frequent arguments
if filtering == 'entityfilter':
    for c in CLUSTERS:
        df = _keep_for_cluster(df, c, _top_entities(c))

# Cluster counts for which annotations exist (x-axis of the figure).
CLUSTERS = [100, 500, 1000, 2000]


def _mean_similarity(frame, cluster, label):
    """Mean annotator ``similarity`` for one cluster count and one ``true`` label.

    ``label`` 1 selects the rows plotted as ``true_clusters``, 0 those plotted
    as ``placebo``.
    """
    return frame[(frame.cluster == cluster) & (frame.true == label)].similarity.mean()


print('Annotated narratives after frequency filter:')
print(len(df))

# Compute each mean once. Print true then placebo per cluster, as before.
true_means = []
placebo_means = []
for c in CLUSTERS:
    true_means.append(_mean_similarity(df, c, 1))
    placebo_means.append(_mean_similarity(df, c, 0))
    print(true_means[-1])
    print(placebo_means[-1])

# Column order ('placebo', 'true_clusters') determines the legend labels below.
plot_data = pd.DataFrame(
    {'placebo': placebo_means, 'true_clusters': true_means},
    index=pd.Index(CLUSTERS, name='clusters'),
)

##############################################################
# Plot similarity by clusters

# Output figure for each filtering mode.
FIGURE_PATHS = {
    'narrativefilter': '../figures/Figure_B_3_a.png',
    'entityfilter': '../figures/Figure_B_3_b.png',
}
# Fail before plotting instead of silently writing no figure.
if filtering not in FIGURE_PATHS:
    raise ValueError('No output figure defined for filter {0!r}'.format(filtering))

# Legend labels are the plot_data column names.
_legends = list(plot_data)

_fig, _ax = plt.subplots(figsize=(29,15))
# Axis labels are set before plotting; keep this order.
# NOTE(review): seaborn may otherwise label the x-axis from the index name
# ('clusters'). Confirm for the seaborn version in use.
_ax.set_xlabel('Clusters', fontsize=25)
_ax.set_ylabel('Similarity: 0 (not similar at all) to 10 (very similar/the same)', fontsize=25)
# Alpha 0.0: the axes background is fully transparent.
_ax.set_facecolor((.18, .31, .31, 0.0))
_ax.grid(color='#dfede2')

# One line per plot_data column, x = index (cluster count).
sns.lineplot(data=plot_data,
             markers=True,
             dashes=False,
             ax=_ax,
             markersize=15,
             linewidth=2.5)
plt.legend(title='Narratives', loc='upper right', labels=_legends, title_fontsize=19, fontsize=15)
plt.xticks(fontsize=16)
plt.yticks(fontsize=16)

plt.savefig(FIGURE_PATHS[filtering])
